import asyncio
import dataclasses
import json
import logging
import uuid
from base64 import b64encode
from datetime import timedelta
from typing import Any

import elasticapm
from asgiref.sync import sync_to_async
from django.conf import settings
from django.core.exceptions import ValidationError
from django.urls import reverse
from django.utils import dateparse, timezone
from lxml import etree

from sysreptor.pentests import cvss
from sysreptor.pentests.fielddefinition.sort import group_findings, sort_findings
from sysreptor.pentests.models import (
    Language,
    NoteType,
    PentestProject,
    ProjectMemberInfo,
    ProjectNotebookPage,
    ProjectType,
    UserNotebookPage,
)
from sysreptor.pentests.rendering import tasks
from sysreptor.pentests.rendering.error_messages import (
    ErrorMessage,
    MessageLevel,
    MessageLocationInfo,
    MessageLocationType,
)
from sysreptor.pentests.rendering.render_utils import RenderStageResult
from sysreptor.users.models import PentestUser
from sysreptor.utils.configuration import configuration
from sysreptor.utils.fielddefinition.types import (
    BaseField,
    CweField,
    EnumChoice,
    FieldDataType,
    FieldDefinition,
    ObjectField,
)
from sysreptor.utils.fielddefinition.utils import (
    HandleUndefinedFieldsOptions,
    ensure_defined_structure,
    iterate_fields,
    set_value_at_path,
)
from sysreptor.utils.logging import log_timing
from sysreptor.utils.utils import copy_keys, get_key_or_attr, merge

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)


def format_template_field_object(value: dict, definition: FieldDefinition|ObjectField, members: list[dict | ProjectMemberInfo] | None = None, require_id=False):
    """
    Format an object value (a dict of fields) for use in report templates.

    The input is overlaid with the result of ``ensure_defined_structure`` and
    every field listed in ``definition.fields`` is then converted with
    ``format_template_field``. Keys of ``value`` that are not part of the
    definition are kept as they are.

    :param value: raw field values of the object
    :param definition: definition whose ``fields`` describe the object
    :param members: known users, passed through to user field formatting
    :param require_id: when true, a random UUID string is stored as ``id`` if the result has no ``id`` key
    :return: a new dict with formatted field values
    """
    structured = ensure_defined_structure(value=value, definition=definition)
    formatted = value | structured
    for field in definition.fields:
        raw_value = formatted.get(field.id)
        formatted[field.id] = format_template_field(value=raw_value, definition=field, members=members)

    if not require_id or 'id' in formatted:
        return formatted
    formatted['id'] = str(uuid.uuid4())
    return formatted


def format_template_field_user(value: ProjectMemberInfo | str | uuid.UUID | None, members: list[dict | ProjectMemberInfo | PentestUser] | None = None):
    """
    Format a user reference for use in report templates.

    ``value`` may already be a user-like object (``ProjectMemberInfo``,
    ``PentestUser``, dict or ``None``) or a user ID (string or UUID). IDs are
    looked up in ``members`` first and in the database afterwards.

    :return: dict with the copied user attributes plus a sorted ``roles`` list,
        or ``None`` if the value is empty, of an unsupported type, or the user cannot be found
    """
    user_keys = ['id', 'name', 'title_before', 'first_name', 'middle_name', 'last_name', 'title_after', 'email', 'phone', 'mobile']
    # Known roles are listed in this order; any other role is ranked after them.
    role_rank = {'lead': 0, 'pentester': 1, 'reviewer': 2}

    def format_user(u: ProjectMemberInfo | dict | None):
        if not u:
            return None
        user_source = u.user if isinstance(u, ProjectMemberInfo) else u
        user_fields = copy_keys(user_source, user_keys)
        unique_roles = {r for r in get_key_or_attr(u, 'roles', []) if r}
        return user_fields | {'roles': sorted(unique_roles, key=lambda r: role_rank.get(r, 10))}

    # Already a user-like object (or nothing at all)
    if isinstance(value, ProjectMemberInfo|PentestUser|dict|None):
        return format_user(value)
    # Anything that is not an ID cannot be resolved
    if not isinstance(value, str|uuid.UUID):
        return None

    # Prefer a matching entry from the provided members
    wanted_id = str(value)
    known_member = next((m for m in members or [] if str(get_key_or_attr(m, 'id')) == wanted_id), None)
    if known_member:
        return format_user(known_member)

    # Fall back to loading the user from the database (without roles)
    try:
        return format_user(ProjectMemberInfo(user=PentestUser.objects.get(id=value), roles=[]))
    except (PentestUser.DoesNotExist, ValidationError):
        return None


def format_template_field_cwe(value: str|None, parents=True):
    """
    Format a CWE identifier (e.g. ``"CWE-79"``) for use in report templates.

    The matching entry of ``CweField.cwe_definitions()`` is merged over a dict of
    default keys (``id``, ``name``, ``description``, ``value``, ``parent``).
    ``parent`` is rewritten to the ``CWE-<id>`` notation or ``None``.

    :param value: CWE identifier or ``None``
    :param parents: when true, add a ``parents`` list with the formatted chain of
        ancestors, nearest parent first
    :return: dict describing the CWE; the defaults are returned if nothing matches
    """
    matching_definition = {}
    for candidate in CweField.cwe_definitions():
        if value == f"CWE-{candidate['id']}":
            matching_definition = candidate
            break

    formatted = {
        'id': None,
        'name': None,
        'description': None,
        'value': value,
        'parent': None,
    } | matching_definition

    if formatted.get('parent'):
        formatted['parent'] = f'CWE-{formatted["parent"]}'
    else:
        formatted['parent'] = None

    if not parents:
        return formatted

    # Walk up the hierarchy; each ancestor is formatted without its own parents list.
    ancestors = []
    formatted['parents'] = ancestors
    next_cwe = formatted.get('parent')
    while next_cwe:
        ancestor = format_template_field_cwe(value=next_cwe, parents=False)
        ancestors.append(ancestor)
        next_cwe = ancestor.get('parent')
    return formatted


def format_template_field(value: Any, definition: BaseField, members: list[dict | ProjectMemberInfo] | None = None):
    """
    Convert a single field value into the form exposed to report templates.

    The conversion is selected by ``definition.type``:

    * ENUM: the matching choice as dict, or an empty choice if none matches
    * CVSS: the calculated metrics plus ``vector``, ``score``, ``level`` and ``level_number``
    * CWE: see ``format_template_field_cwe``
    * JSON: the parsed JSON document, or ``None`` if it cannot be parsed
    * USER: see ``format_template_field_user``
    * LIST / OBJECT: formatted recursively
    * everything else: returned unchanged
    """
    field_type = definition.type

    if field_type == FieldDataType.ENUM:
        empty_choice = EnumChoice(value='', label='')
        selected = next((c for c in definition.choices if c.value == value), empty_choice)
        return dataclasses.asdict(selected)

    if field_type == FieldDataType.CVSS:
        metrics = cvss.calculate_metrics(value)
        final_score = metrics["final"]["score"]
        return metrics | {
            'vector': value,
            'score': str(round(final_score, 2)),
            'level': cvss.level_from_score(final_score).value,
            'level_number': cvss.level_number_from_score(final_score),
        }

    if field_type == FieldDataType.CWE:
        return format_template_field_cwe(value)

    if field_type == FieldDataType.JSON:
        try:
            return json.loads(value)
        except (TypeError, json.JSONDecodeError):
            return None

    if field_type == FieldDataType.USER:
        return format_template_field_user(value, members=members)

    if field_type == FieldDataType.LIST:
        return [
            format_template_field(value=item, definition=definition.items, members=members)
            for item in value
        ]

    if field_type == FieldDataType.OBJECT:
        return format_template_field_object(value=value, definition=definition, members=members)

    return value


def format_template_data(
        data: dict, project_type: ProjectType, imported_members: list[dict] | None = None,
        override_finding_order=False, additional_data: dict|None = None,
):
    """
    Bring raw report data into the structure that is passed to report templates.

    ``data`` is modified in place: ``report`` is replaced by the formatted report
    fields, ``sections`` and ``finding_groups`` are added, ``findings`` is removed
    (findings are only available via ``finding_groups``) and ``pentesters`` is
    replaced by the formatted and sorted member list.

    :param data: dict with the optional keys ``report``, ``findings`` and ``pentesters``
    :param project_type: design whose field definitions and sections are used
    :param imported_members: additional members; they are appended to ``pentesters``
        and used to resolve user IDs
    :param override_finding_order: passed to ``sort_findings`` and ``group_findings``
    :param additional_data: merged into the result with ``merge`` as last step
    :return: the result of merging the formatted data with ``additional_data``
    """
    # Users that cannot be resolved are formatted as None. Drop them here,
    # because every later step (ID lookup, sorting) expects dict entries.
    formatted_users = (
        format_template_field_user(u, members=imported_members)
        for u in data.get('pentesters', []) + (imported_members or [])
    )
    members = [m for m in formatted_users if m is not None]

    data['report'] = format_template_field_object(
        value=ensure_defined_structure(
            value=data.get('report', {}),
            definition=project_type.all_report_fields_obj,
            handle_undefined=HandleUndefinedFieldsOptions.FILL_DEFAULT),
        definition=project_type.all_report_fields_obj,
        members=members,
        require_id=True)
    data['sections'] = {
        s.get('id', ''): {
            'id': s.get('id', ''),
            'label': s.get('label', ''),
            **copy_keys(data['report'], [f.get('id') for f in s.get('fields', [])]),
        } for s in project_type.report_sections
    }
    data['findings'] = sort_findings(findings=[
        format_template_field_object(
            value={
                # Defaults for findings without these keys. The ID is a string,
                # like every other ID that is passed to the templates.
                'id': str(uuid.uuid4()),
                'created': timezone.now().isoformat(),
            } | (f if isinstance(f, dict) else {}) | ensure_defined_structure(
                value=f,
                definition=project_type.finding_fields_obj,
                handle_undefined=HandleUndefinedFieldsOptions.FILL_DEFAULT),
            definition=project_type.finding_fields_obj,
            members=members,
            require_id=True)
        for f in data.get('findings', [])],
        project_type=project_type, override_finding_order=override_finding_order)
    # Number the findings (1-based) according to their sorted position
    data['findings'] = [f | {'order': fidx + 1} for fidx, f in enumerate(data['findings'])]
    data['finding_groups'] = group_findings(findings=data['findings'], project_type=project_type, override_finding_order=override_finding_order)
    del data['findings']

    def pentester_sort_key(u: dict):
        roles = u.get('roles', [])
        if 'lead' in roles:
            rank = 0
        elif 'pentester' in roles:
            rank = 1
        elif 'reviewer' in roles:
            rank = 2
        else:
            rank = 10
        # NOTE(review): format_template_field_user does not list 'username' among the
        # copied keys, so this tie-breaker is presumably always None - confirm whether
        # a different key was intended.
        return (rank, u.get('username'))

    data['pentesters'] = sorted(members, key=pentester_sort_key)
    data = merge(data, additional_data or {})
    return data


def format_project_template_data(project: PentestProject, project_type: ProjectType | None = None, additional_data: dict|None = None):
    """
    Collect the data of a project and format it for report templates.

    :param project: project whose report data, findings and members are read
    :param project_type: design to format the data for; defaults to the project's own design
    :param additional_data: passed through to ``format_template_data``
    :return: the result of ``format_template_data``
    """
    project_type = project_type or project.project_type

    # Meta information first, so that equally named project fields take precedence
    report = {
        'id': str(project.id),
        'created': project.created.isoformat(),
        'language': project.language,
        'tags': project.tags,
        **project.data,
    }
    findings = [
        {
            'id': str(finding.finding_id),
            'created': finding.created.isoformat(),
            'order': finding.order,
            **finding.data,
        }
        for finding in project.findings.all()
    ]
    pentesters = list(project.members.all())

    return format_template_data(
        data={
            'report': report,
            'findings': findings,
            'pentesters': pentesters,
        },
        project_type=project_type,
        imported_members=project.imported_members,
        override_finding_order=project.override_finding_order,
        additional_data=additional_data,
    )


@sync_to_async
def aformat_project_template_data(project: PentestProject, project_type: ProjectType | None = None, additional_data: dict|None = None):
    """Awaitable variant of ``format_project_template_data`` (wrapped with ``sync_to_async``)."""
    return format_project_template_data(
        project,
        project_type=project_type,
        additional_data=additional_data,
    )


async def get_celery_result_async(task):
    """
    Wait for a task result without blocking the event loop.

    ``task`` is polled via ``task.ready()`` every 100ms. It is presumably a
    Celery ``AsyncResult`` (this function is called with the return value of
    ``.delay()``) - only ``ready()``, ``result`` and ``revoke()`` are used.

    :return: ``task.result`` once the task is ready
    :raises Exception: the task result itself, if it is an exception instance
    :raises asyncio.CancelledError: re-raised after trying to revoke the task
        when the waiting coroutine is cancelled
    """
    try:
        while not task.ready():  # noqa: ASYNC110
            await asyncio.sleep(0.1)
        result = task.result
        if isinstance(result, Exception):
            raise result
        return result
    except asyncio.CancelledError:
        # Nobody is waiting for the result anymore: try to stop the task.
        # Revoking is best-effort and must never hide the cancellation itself.
        try:
            await sync_to_async(task.revoke)(terminate=True, wait=False)
        except Exception:
            log.debug('Could not revoke cancelled task', exc_info=True)
        raise


async def _render_pdf_task_async(**kwargs):
    """
    Run the PDF rendering task and wait for its result.

    Depending on ``settings.CELERY_TASK_ALWAYS_EAGER`` the task is either awaited
    directly in this process or submitted to Celery and polled for its result.
    If ``settings.PDF_RENDERING_TIME_LIMIT`` is set, waiting is aborted 5 seconds
    after that limit; otherwise there is no timeout.

    :param kwargs: passed unchanged to the rendering task
    :return: ``RenderStageResult`` created from the dict returned by the task
    :raises TimeoutError: if the task did not finish in time
    :raises asyncio.CancelledError: re-raised if the coroutine is cancelled
    """
    timeout = timedelta(seconds=settings.PDF_RENDERING_TIME_LIMIT + 5).total_seconds() if settings.PDF_RENDERING_TIME_LIMIT else None

    try:
        async with asyncio.timeout(timeout):
            if settings.CELERY_TASK_ALWAYS_EAGER:
                # Do not use celery when tasks are executed eagerly in the same process
                # Use async instead to be able to cancel tasks.
                # sync_to_async functions are not cancelled because the ThreadPoolExecutor does not support cancellation.
                # Tasks continue running in background, even when the asyncio coroutine is already cancelled.
                res = await tasks.render_pdf_task_async(**kwargs)
            else:
                task = await sync_to_async(tasks.render_pdf_task_celery.delay)(**kwargs)
                res = await get_celery_result_async(task)
        return RenderStageResult.from_dict(res)
    except asyncio.CancelledError:
        # Log via the module logger (not the root logger), like the rest of this module
        log.info('PDF rendering task cancelled')
        raise
    except TimeoutError as ex:
        log.error('PDF rendering task timeout')
        raise TimeoutError('PDF rendering timeout') from ex


@elasticapm.async_capture_span()
@log_timing(log_start=True, log_detailed_timings=True)
async def render_pdf_task(
    project_type: ProjectType, report_template: str, report_styles: str, data: dict,
    password: str | None = None, can_compress_pdf: bool = False, project: PentestProject | None = None,
    output=None, html=None, timings=None,
) -> RenderStageResult:
    """
    Render a report with the given template, styles and data and collect timing information.

    :param project_type: design that provides the assets and, if no project is given, the language
    :param report_template: template passed to the rendering task
    :param report_styles: styles passed to the rendering task
    :param data: template data passed to the rendering task
    :param password: passed to the rendering task
    :param can_compress_pdf: allow PDF compression; it is only requested if the
        ``COMPRESS_PDFS`` configuration value is enabled as well
    :param project: optional project; provides the language and the referenced images
    :param output: passed to the rendering task
    :param html: passed to the rendering task
    :param timings: already measured timings to start from
    :return: rendering result; messages without a location are attributed to the design
    """
    res = RenderStageResult(timings=timings or {})

    @sync_to_async()
    def format_resources():
        # Map resource paths to base64 encoded file contents
        resources = {}
        resources |= {'/assets/name/' + a.name: b64encode(a.file.read()).decode() for a in project_type.assets.all()}
        if project:
            # Only include images that are referenced in sections or findings (not in notes)
            resources |= {'/images/name/' + i.name: b64encode(i.file.read()).decode() for i in project.images.all() if project.is_file_referenced(i, sections=True, findings=True, notes=False)}
        return resources

    with res.add_timing('collect_data'):
        resources = await format_resources()

    res.timings.setdefault('queue', 0.0)
    before_task_start = timezone.now()
    # Sum of all timings measured before the task; used below to compute the 'other' timing
    timing_before_task_total = sum(res.timings.values())
    with res.add_timing('task_total'):
        res_pdf = await _render_pdf_task_async(
            template=report_template,
            styles=report_styles,
            data=data,
            language=project.language if project else project_type.language,
            password=password,
            compress_pdf=can_compress_pdf and (await configuration.aget('COMPRESS_PDFS')),
            accessible_pdf=(await configuration.aget('GENERATE_ACCESSIBLE_PDFS')),
            output=output,
            html=html,
            resources=resources,
        )
    # NOTE(review): presumably merges the task's timings, messages and output into res - confirm in RenderStageResult
    res |= res_pdf
    if (task_start_time := res.other.pop('task_start_time', None)):
        # use datetimes instead of perf_counter, because the task might be executed by a worker on a different machine and perf_counter is not synchronized
        res.timings['queue'] += (dateparse.parse_datetime(task_start_time) - before_task_start).total_seconds()
    # Replace 'task_total' by 'other': the part of the total task duration that is not
    # covered by the timings added since the task was started (never negative)
    res.timings['other'] = res.timings.get('other', 0) + max(0, res.timings.pop('task_total') + timing_before_task_total - sum(v for v in res.timings.values()))

    # Set message location info to ProjectType (if not available)
    res.messages = [
        (m if m.location else dataclasses.replace(m, location=MessageLocationInfo(type=MessageLocationType.DESIGN, id=project_type.id, name=project_type.name)))
        for m in res.messages
    ]
    return res


@elasticapm.async_capture_span()
async def render_project_markdown_fields_to_html(project: PentestProject, request) -> dict:
    """
    Render all markdown fields of a project to HTML and return the serialized project
    with the markdown field values replaced by the rendered HTML.

    Markdown rendering is done in Chromium similar to the PDF rendering.
    This is required because our markdown renderer (with custom extensions) is implemented in JS which cannot be used in Python
    and we are able to evaluate Vue template language embedded in markdown fields.

    :return: dict with ``result`` (serialized project) and ``messages``; if rendering
        produced no output or the output cannot be processed, the rendering result as dict
    """

    def collect_markdown(field_data, definition, path_prefix):
        # Map the JSON encoded field path to the markdown text of every markdown field
        return {
            json.dumps(path): value
            for (path, value, field_definition) in iterate_fields(value=field_data, definition=definition, path=path_prefix)
            if field_definition.type == FieldDataType.MARKDOWN
        }

    # Collect all markdown fields of sections and findings
    markdown_fields = {}
    async for section in project.sections.all():
        markdown_fields.update(collect_markdown(
            section.data, project.project_type.all_report_fields_obj, ('sections', str(section.section_id))))
    async for finding in project.findings.all():
        markdown_fields.update(collect_markdown(
            finding.data, project.project_type.finding_fields_obj, ('findings', str(finding.finding_id))))

    # Render markdown fields to HTML
    template_data = await aformat_project_template_data(project=project)
    data = template_data | {'markdown_fields': markdown_fields}
    res = await render_pdf_task(
        project_type=project.project_type,
        report_template="""<markdown v-for="([id, text]) in Object.entries(data.markdown_fields)" :id="id" :text="text" />""",
        report_styles="",
        data=data,
        output='html',
    )
    if not res.pdf:
        return res.to_dict()

    def format_output():
        from sysreptor.pentests.serializers.project import PentestProjectDetailSerializer

        # Each rendered markdown field is an element whose id is the JSON encoded field path
        document = etree.HTML(res.pdf.decode())
        for node in document.find('body/div').getchildren():
            node_id = node.attrib.get('id')
            if node_id not in markdown_fields:
                continue
            markdown_fields[node_id] = ''.join(
                etree.tostring(child, method="html", pretty_print=True).decode()
                for child in node.getchildren()
            )

        # Serialize the project and write the HTML to the position of each markdown field
        serialized = PentestProjectDetailSerializer(instance=project, context={'request': request}).data
        for path_str, html in markdown_fields.items():
            path = json.loads(path_str)
            kind = path[0]
            if kind not in ('sections', 'findings'):
                continue
            target = next(item for item in serialized[kind] if item['id'] == path[1])
            set_value_at_path(target['data'], path[2:], html)

        return {
            'result': serialized,
            'messages': res.to_dict()['messages'],
        }

    try:
        return await sync_to_async(format_output)()
    except Exception:
        log.exception('Error while formatting output')
        failure = ErrorMessage(
            level=MessageLevel.ERROR,
            message='Error while formatting output',
        )
        res.messages.append(failure)
        return res.to_dict()


@elasticapm.async_capture_span()
async def render_note_to_pdf(notes: list[ProjectNotebookPage] | list[UserNotebookPage], request=None) -> RenderStageResult:
    """
    Render notes to a single PDF, one section per note.

    Whether the notes are project notes or user notes is decided by the type of the
    first note; images and files are taken from its project or user.

    :param notes: notes to render; only notes of type ``NoteType.TEXT`` contribute their text
    :param request: if given, links to uploaded files in the note texts are rewritten
        to absolute URLs (the ``text`` attribute of the note objects is modified)
    :return: rendering result; contains an error message if ``notes`` is empty
    """
    if not notes:
        no_notes_error = ErrorMessage(
            level=MessageLevel.ERROR,
            message='No notes provided for rendering',
        )
        return RenderStageResult(messages=[no_notes_error])

    first_note = notes[0]
    is_project_note = isinstance(first_note, ProjectNotebookPage)
    parent_obj = first_note.project if is_project_note else first_note.user

    def absolute_file_url(note, uploaded_file):
        # Build the absolute download URL of a file of the note's project or user
        if is_project_note:
            url_name = 'uploadedprojectfile-retrieve-by-name'
            url_kwargs = {'project_pk': note.project.id, 'filename': uploaded_file.name}
        else:
            url_name = 'uploadedusernotebookfile-retrieve-by-name'
            url_kwargs = {'pentestuser_pk': note.user.id, 'filename': uploaded_file.name}
        return request.build_absolute_uri(reverse(url_name, kwargs=url_kwargs))

    res = RenderStageResult()
    with res.add_timing('collect_data'):
        # Prevent sending unreferenced images to rendering task to reduce memory consumption
        resources = {}
        async for image in parent_obj.images.all():
            if not any(n.is_file_referenced(image) for n in notes):
                continue
            resources['/images/name/' + image.name] = b64encode(image.file.read()).decode()

        # Rewrite file links to absolute URLs
        if request:
            async for uploaded_file in parent_obj.files.only('id', 'name'):
                for note in notes:
                    if not note.is_file_referenced(uploaded_file):
                        continue
                    note.text = note.text.replace(f'/files/name/{uploaded_file.name}', absolute_file_url(note, uploaded_file))

    note_data = [
        {
            'id': str(note.note_id),
            'title': note.title,
            'text': note.text if note.type == NoteType.TEXT else '',
        }
        for note in notes
    ]
    rendered = await _render_pdf_task_async(
        template="""
        <section v-for="note, idx in data.notes">
            <h1>{{ note.title }}</h1>
            <markdown :text="note.text" />
            <pagebreak v-if="idx < data.notes.length - 1" />
        </section>""",
        styles="""@import "/assets/global/base.css";""",
        data={'notes': note_data},
        language=parent_obj.language if is_project_note else Language.ENGLISH_US,
        resources=resources,
    )
    res |= rendered
    return res


async def render_pdf(
    project: PentestProject, project_type: ProjectType | None = None,
    report_template: str | None = None, report_styles: str | None = None,
    additional_data: dict | None =None,
    password: str | None = None, can_compress_pdf: bool = False, output: str | None = None,
) -> RenderStageResult:
    """
    Render the report of a project.

    ``project_type`` defaults to the design of the project; ``report_template`` and
    ``report_styles`` default to the values of the used design if they are empty.
    The time needed to format the project data is recorded as ``collect_data`` timing
    and handed over to ``render_pdf_task``.
    """
    project_type = project_type or project.project_type
    report_template = report_template or project_type.report_template
    report_styles = report_styles or project_type.report_styles

    timing_result = RenderStageResult()
    with timing_result.add_timing('collect_data'):
        data = await aformat_project_template_data(
            project=project,
            project_type=project_type,
            additional_data=additional_data,
        )

    render_kwargs = {
        'project': project,
        'project_type': project_type,
        'report_template': report_template,
        'report_styles': report_styles,
        'data': data,
        'password': password,
        'can_compress_pdf': can_compress_pdf,
        'timings': timing_result.timings,
        'output': output,
    }
    return await render_pdf_task(**render_kwargs)


async def render_pdf_preview(project_type: ProjectType, report_template: str, report_styles: str, report_preview_data: dict) -> RenderStageResult:
    """
    Render a preview of a design with the given template, styles and preview data.

    A shallow copy of ``report_preview_data`` is formatted, so the top-level keys of
    the passed dict are not replaced. Formatting time is recorded as ``collect_data`` timing.
    """
    timing_result = RenderStageResult()
    with timing_result.add_timing('collect_data'):
        data_copy = report_preview_data.copy()
        format_data = sync_to_async(format_template_data)
        data = await format_data(data=data_copy, project_type=project_type)

    return await render_pdf_task(
        project_type=project_type,
        report_template=report_template,
        report_styles=report_styles,
        data=data,
        timings=timing_result.timings,
    )
